import subprocess
import sys

import pkg_resources

def remove_package(package):
    """Uninstall *package* with pip, without asking for confirmation.

    pip is invoked as ``<current interpreter> -m pip`` so the package is
    removed from the environment that is running this script, not from
    whichever ``pip`` executable happens to come first on PATH.

    Args:
        package: Name of the distribution to uninstall.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    subprocess.check_call(
        [sys.executable, '-m', 'pip', 'uninstall', package, "--yes"]
    )

def install_package(package, version):
    """Install exactly ``package==version`` with pip.

    pip is invoked as ``<current interpreter> -m pip`` so the package lands
    in the environment that is running this script.

    Args:
        package: Name of the distribution to install.
        version: Exact version string to pin (used as ``package==version``).

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    command = [sys.executable, '-m', 'pip', 'install', f"{package}=={version}"]
    # pip refuses a --user install inside a virtual environment whose user
    # site-packages are not visible, so only request the per-user location
    # when running outside a virtual environment.
    if sys.prefix == sys.base_prefix:
        command.append("--user")
    subprocess.check_call(command)

def check_and_install_packages(package_dict):
    """Make sure every package in *package_dict* is installed at its pinned version.

    For each ``name -> version`` entry:

    * if the installed version already equals the pin, nothing is done;
    * if a different version is installed, it is removed and the pinned
      version is installed;
    * if the package is not installed, the pinned version is installed.

    Args:
        package_dict: Mapping of distribution name to exact version string.

    Raises:
        subprocess.CalledProcessError: If a pip uninstall/install command
            fails.
    """
    for package, version in package_dict.items():
        try:
            installed_version = pkg_resources.get_distribution(package).version
        except pkg_resources.DistributionNotFound:
            print(f"{package} is not installed.")
        else:
            if installed_version == version:
                # Already at the requested pin: skip the costly
                # uninstall/reinstall round trip.
                print(f"{package}=={version} is already installed.")
                continue
            print(f"{package} is already installed. Removing and reinstalling...")
            remove_package(package)

        install_package(package, version)
        print(f"{package} installed successfully.")

# Third-party packages this script needs, pinned to exact versions.
# NOTE: "pathlib" is deliberately not listed. It has been part of the
# standard library since Python 3.4, and the PyPI distribution of the same
# name is an obsolete backport that should not be installed on Python 3.
package_dict = {
    "Rtree": "1.0.1",
    "seaborn": "0.12.2",
    "geopandas": "0.13.2",
    "contextily": "1.3.0",
    "shapely": "2.0.1",
    "numpy": "1.24.0",
    "pandas": "2.0.3",
    "Pillow": "10.0.0",
}
# Check and install packages if necessary
check_and_install_packages(package_dict)


# %%
import os
import numpy as np
import pandas as pd
from PIL import Image
from pathlib import Path
import seaborn as sns
import geopandas as gpd
import contextily as cx
import matplotlib
import matplotlib.pyplot as plt
# Base matplotlib defaults: font weight/size and figure size.
font = {'weight': 'normal',
        'size': 10}
plt.rc('font', **font)
plt.rcParams['figure.figsize'] = (10, 10)

# Matplotlib 3.6 renamed its bundled seaborn style sheets to
# "seaborn-v0_8-<name>" and later releases removed the old names, so use
# whichever spelling this matplotlib installation provides.
seaborn_styles = [
    f"seaborn-v0_8-{name}"
    if f"seaborn-v0_8-{name}" in matplotlib.style.available
    else f"seaborn-{name}"
    for name in ("talk", "ticks", "whitegrid")
]
matplotlib.style.use(seaborn_styles)

# Work from the directory that contains this file so the relative data paths
# below resolve. __file__ may be undefined when the cells are executed in an
# interactive session; in that case keep the current working directory.
try:
    file_path = os.path.abspath(__file__)
    directory = os.path.dirname(file_path)
except NameError:
    file_path = None
    directory = os.getcwd()

# Set the working directory to the file's directory
os.chdir(directory)

# plt.style.use("dark_background") # dark bg plots
# NOTE(review): this runs after the matplotlib settings above and may
# override some of them - confirm the intended precedence.
sns.set(style="ticks", context="talk")

# Import blocks
block2 = gpd.read_parquet("BLOCKS_sch_coded.spq")
# Import states
state = gpd.read_file("IND_adm1.shp")

# %%
# Text label for the scheduled flag: rows where `sch` equals 1 become
# "Scheduled", every other row becomes "Non-Scheduled".
is_scheduled = block2.sch == 1
block2['sched_str'] = np.where(is_scheduled, "Scheduled", "Non-Scheduled")
# Count of blocks per label (the result is only shown when this is run as
# an interactive cell).
block2['sched_str'].value_counts()
# %% # # Plot Treatment
# Map extent: the bounding box of all blocks.
xmin, ymin, xmax, ymax = block2.total_bounds
f, ax = plt.subplots(1, figsize=(12, 12), dpi=100)

# Blocks coloured by their scheduled / non-scheduled label, with a legend.
block2.plot(column='sched_str',
            categorical=True,
            legend=True, alpha=0.8,
            cmap='Set1', edgecolor='k',
            linewidth=0.3, ax=ax)
# State boundaries drawn as unfilled outlines on the same axes.
state.plot(facecolor='none',
           categorical=True,
           legend=True,
           edgecolor='y',
           linewidth=1,
           ax=ax)

# Enlarge every legend label. Iterating (instead of indexing entries 0 and 1)
# keeps this working when the data holds fewer or more than two categories,
# and when no legend was created at all.
legend = ax.get_legend()
if legend is not None:
    for legend_text in legend.get_texts():
        legend_text.set_fontsize(20)

# Define limits
ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)

# Newer xyzservices releases no longer ship the Stamen providers; fall back
# to a light CartoDB basemap when the Stamen entry is missing.
try:
    basemap_source = cx.providers.Stamen.TonerLite
except AttributeError:
    basemap_source = cx.providers.CartoDB.Positron
# NOTE(review): the basemap is requested in `state`'s CRS; this assumes
# `block2` and `state` share that CRS - confirm.
cx.add_basemap(ax,
               crs=state.crs.to_string(),
               source=basemap_source)
ax.set_axis_off()

# High-resolution raster export.
f.savefig('main_figure1.png', dpi=400)
# Convert the PNG to an RGB image and write it as a PDF. The context manager
# closes the PNG file handle once the converted copy exists.
with Image.open(r'main_figure1.png') as image_1:
    im_1 = image_1.convert('RGB')
im_1.save(r'main_figure1.pdf')
# NOTE(review): this overwrites the PDF written just above with matplotlib's
# own PDF output at dpi=10, so the PIL conversion has no lasting effect.
# Confirm which of the two PDFs is the intended deliverable.
f.savefig('main_figure1.pdf', dpi=10)